{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## RNNs tutorial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# we assume that we have the dynet module in your path.\n",
    "import dynet as dy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### An LSTM/RNN overview:\n",
    "\n",
    "An (1-layer) RNN can be thought of as a sequence of cells, $h_1,...,h_k$, where $h_i$ indicates the time dimenstion. \n",
    "\n",
    "Each cell $h_i$ has an input $x_i$ and an output $r_i$. In addition to $x_i$, cell $h_i$ receives as input also $r_{i-1}$.\n",
    "\n",
    "In a deep (multi-layer) RNN, we don't have a sequence, but a grid. That is we have several layers of sequences:\n",
    "\n",
    "* $h_1^3,...,h_k^3$ \n",
    "* $h_1^2,...,h_k^2$ \n",
    "* $h_1^1,...h_k^1$, \n",
    "\n",
    "Let $r_i^j$ be the output of cell $h_i^j$. Then:\n",
    "\n",
    "The input to $h_i^1$ is $x_i$ and $r_{i-1}^1$.\n",
    "\n",
    "The input to $h_i^2$ is $r_i^1$ and $r_{i-1}^2$,\n",
    "and so on.\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The LSTM (RNN) Interface\n",
    "\n",
    "RNN / LSTM / GRU follow the same interface. We have a \"builder\" which is in charge of creating definining the parameters for the sequence."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "pc = dy.ParameterCollection()\n",
    "NUM_LAYERS=2\n",
    "INPUT_DIM=50\n",
    "HIDDEN_DIM=10\n",
    "builder = dy.LSTMBuilder(NUM_LAYERS, INPUT_DIM, HIDDEN_DIM, pc)\n",
    "# or:\n",
    "# builder = dy.SimpleRNNBuilder(NUM_LAYERS, INPUT_DIM, HIDDEN_DIM, pc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that when we create the builder, it adds the internal RNN parameters to the `ParameterCollection`.\n",
    "We do not need to care about them, but they will be optimized together with the rest of the network's parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "s0 = builder.initial_state()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "x1 = dy.vecInput(INPUT_DIM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "s1=s0.add_input(x1)\n",
    "y1 = s1.output()\n",
    "# here, we add x1 to the RNN, and the output we get from the top is y (a HIDEN_DIM-dim vector)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(10,)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y1.npvalue().shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "s2=s1.add_input(x1) # we can add another input\n",
    "y2=s2.output()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If our LSTM/RNN was one layer deep, y2 would be equal to the hidden state. However, since it is 2 layers deep, y2 is only the hidden state (= output) of the last layer."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we were to want access to the all the hidden state (the output of both the first and the last layers), we could use the `.h()` method, which returns a list of expressions, one for each layer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(expression 47/0, expression 62/0)\n"
     ]
    }
   ],
   "source": [
    "print(s2.h())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same interface that we saw until now for the LSTM, holds also for the Simple RNN:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "all layers: (expression 19/0, expression 32/0)\n"
     ]
    }
   ],
   "source": [
    "# create a simple rnn builder\n",
    "rnnbuilder=dy.SimpleRNNBuilder(NUM_LAYERS, INPUT_DIM, HIDDEN_DIM, pc)\n",
    "\n",
    "# initialize a new graph, and a new sequence\n",
    "rs0 = rnnbuilder.initial_state()\n",
    "\n",
    "# add inputs\n",
    "rs1 = rs0.add_input(x1)\n",
    "ry1 = rs1.output()\n",
    "print(\"all layers:\", s1.h())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(expression 17/0, expression 30/0, expression 19/0, expression 32/0)\n"
     ]
    }
   ],
   "source": [
    "print(s1.s())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To summarize, when calling `.add_input(x)` on an `RNNState` what happens is that the state creates a new RNN/LSTM column, passing it: \n",
    "1. the state of the current RNN column\n",
    "2. the input `x`\n",
    "\n",
    "The state is then returned, and we can call it's `output()` method to get the output `y`, which is the output at the top of the column. We can access the outputs of all the layers (not only the last one) using the `.h()` method of the state.\n",
    "\n",
    "**`.s()`** The internal state of the RNN may be more involved than just the outputs $h$. This is the case for the LSTM, that keeps an extra \"memory\" cell, that is used when calculating $h$, and which is also passed to the next column.  To access the entire hidden state, we use the `.s()` method. \n",
    "\n",
    "The output of `.s()` differs by the type of RNN being used. For the simple-RNN, it is the same as `.h()`. For the LSTM, it is more involved.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RNN h: (expression 70/0, expression 72/0)\n",
      "RNN s: (expression 70/0, expression 72/0)\n",
      "LSTM h: (expression 19/0, expression 32/0)\n",
      "LSTM s: (expression 17/0, expression 30/0, expression 19/0, expression 32/0)\n"
     ]
    }
   ],
   "source": [
    "rnn_h  = rs1.h()\n",
    "rnn_s  = rs1.s()\n",
    "print(\"RNN h:\", rnn_h)\n",
    "print(\"RNN s:\", rnn_s)\n",
    "\n",
    "\n",
    "lstm_h = s1.h()\n",
    "lstm_s = s1.s()\n",
    "print(\"LSTM h:\", lstm_h)\n",
    "print(\"LSTM s:\", lstm_s\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see, the LSTM has two extra state expressions (one for each hidden layer) before the outputs h."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Extra options in the RNN/LSTM interface\n",
    "\n",
    "**Stack LSTM** The RNN's are shaped as a stack: we can remove the top and continue from the previous state.\n",
    "This is done either by remembering the previous state and continuing it with a new `.add_input()`, or using\n",
    "we can access the previous state of a given state using the `.prev()` method of state.\n",
    "\n",
    "**Initializing a new sequence with a given state** When we call `builder.initial_state()`, we are assuming the state has random /0 initialization. If we want, we can specify a list of expressions that will serve as the initial state. The expected format is the same as the results of a call to `.final_s()`. TODO: this is not supported yet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "s2=s1.add_input(x1)\n",
    "s3=s2.add_input(x1)\n",
    "s4=s3.add_input(x1)\n",
    "\n",
    "# let's continue s3 with a new input.\n",
    "s5=s3.add_input(x1)\n",
    "\n",
    "# we now have two different sequences:\n",
    "# s0,s1,s2,s3,s4\n",
    "# s0,s1,s2,s3,s5\n",
    "# the two sequences share parameters.\n",
    "\n",
    "assert(s5.prev() == s3)\n",
    "assert(s4.prev() == s3)\n",
    "\n",
    "s6=s3.prev().add_input(x1)\n",
    "# we now have an additional sequence:\n",
    "# s0,s1,s2,s6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(expression 207/0, expression 222/0)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "s6.h()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(expression 205/0, expression 220/0, expression 207/0, expression 222/0)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "s6.s()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Aside: memory efficient transduction\n",
    "The `RNNState` interface is convenient, and allows for incremental input construction.\n",
    "However, sometimes we know the sequence of inputs in advance, and care only about the sequence of\n",
    "output expressions. In this case, we can use the `add_inputs(xs)` method, where `xs` is a list of Expression."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[expression 226/0, expression 230/0, expression 234/0] [(expression 224/0, expression 226/0), (expression 228/0, expression 230/0), (expression 232/0, expression 234/0)]\n"
     ]
    }
   ],
   "source": [
    "state = rnnbuilder.initial_state()\n",
    "xs = [x1,x1,x1]\n",
    "states = state.add_inputs(xs)\n",
    "outputs = [s.output() for s in states]\n",
    "hs =      [s.h() for s in states]\n",
    "print(outputs, hs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is convenient.\n",
    "\n",
    "What if we do not care about `.s()` and `.h()`, and do not need to access the previous vectors? In such cases\n",
    "we can use the `transduce(xs)` method instead of `add_inputs(xs)`.\n",
    "`transduce` takes in a sequence of `Expression`s, and returns a sequence of `Expression`s.\n",
    "As a consequence of not returning `RNNState`s, `trnasduce` is much more memory efficient than `add_inputs` or a series of calls to `add_input`.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[expression 238/0, expression 242/0, expression 246/0]\n"
     ]
    }
   ],
   "source": [
    "state = rnnbuilder.initial_state()\n",
    "xs = [x1,x1,x1]\n",
    "outputs = state.transduce(xs)\n",
    "print(outputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "### Character-level LSTM\n",
    "\n",
    "Now that we know the basics of RNNs, let's build a character-level LSTM language-model.\n",
    "We have a sequence LSTM that, at each step, gets as input a character, and needs to predict the next character."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from collections import defaultdict\n",
    "from itertools import count\n",
    "import sys\n",
    "\n",
    "LAYERS = 1\n",
    "INPUT_DIM = 50 \n",
    "HIDDEN_DIM = 50  \n",
    "\n",
    "characters = list(\"abcdefghijklmnopqrstuvwxyz \")\n",
    "characters.append(\"<EOS>\")\n",
    "\n",
    "int2char = list(characters)\n",
    "char2int = {c:i for i,c in enumerate(characters)}\n",
    "\n",
    "VOCAB_SIZE = len(characters)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "pc = dy.ParameterCollection()\n",
    "\n",
    "\n",
    "srnn = dy.SimpleRNNBuilder(LAYERS, INPUT_DIM, HIDDEN_DIM, pc)\n",
    "lstm = dy.LSTMBuilder(LAYERS, INPUT_DIM, HIDDEN_DIM, pc)\n",
    "\n",
    "# add parameters for the hidden->output part for both lstm and srnn\n",
    "params_lstm = {}\n",
    "params_srnn = {}\n",
    "for params in [params_lstm, params_srnn]:\n",
    "    params[\"lookup\"] = pc.add_lookup_parameters((VOCAB_SIZE, INPUT_DIM))\n",
    "    params[\"R\"] = pc.add_parameters((VOCAB_SIZE, HIDDEN_DIM))\n",
    "    params[\"bias\"] = pc.add_parameters((VOCAB_SIZE))\n",
    "\n",
    "# return compute loss of RNN for one sentence\n",
    "def do_one_sentence(rnn, params, sentence):\n",
    "    # setup the sentence\n",
    "    dy.renew_cg()\n",
    "    s0 = rnn.initial_state()\n",
    "    \n",
    "    \n",
    "    R = params[\"R\"]\n",
    "    bias = params[\"bias\"]\n",
    "    lookup = params[\"lookup\"]\n",
    "    sentence = [\"<EOS>\"] + list(sentence) + [\"<EOS>\"]\n",
    "    sentence = [char2int[c] for c in sentence]\n",
    "    s = s0\n",
    "    loss = []\n",
    "    for char,next_char in zip(sentence,sentence[1:]):\n",
    "        s = s.add_input(lookup[char])\n",
    "        probs = dy.softmax(R*s.output() + bias)\n",
    "        loss.append( -dy.log(dy.pick(probs,next_char)) )\n",
    "    loss = dy.esum(loss)\n",
    "    return loss\n",
    " \n",
    "\n",
    "# generate from model:\n",
    "def generate(rnn, params):\n",
    "    def sample(probs):\n",
    "        rnd = random.random()\n",
    "        for i,p in enumerate(probs):\n",
    "            rnd -= p\n",
    "            if rnd <= 0: break\n",
    "        return i\n",
    "    \n",
    "    # setup the sentence\n",
    "    dy.renew_cg()\n",
    "    s0 = rnn.initial_state()\n",
    "    \n",
    "    R = params[\"R\"]\n",
    "    bias = params[\"bias\"]\n",
    "    lookup = params[\"lookup\"]\n",
    "    \n",
    "    s = s0.add_input(lookup[char2int[\"<EOS>\"]])\n",
    "    out=[]\n",
    "    while True:\n",
    "        probs = dy.softmax(R*s.output() + bias)\n",
    "        probs = probs.vec_value()\n",
    "        next_char = sample(probs)\n",
    "        out.append(int2char[next_char])\n",
    "        if out[-1] == \"<EOS>\": break\n",
    "        s = s.add_input(lookup[next_char])\n",
    "    return \"\".join(out[:-1]) # strip the <EOS>\n",
    "        \n",
    "\n",
    "# train, and generate every 5 samples\n",
    "def train(rnn, params, sentence):\n",
    "    trainer = dy.SimpleSGDTrainer(pc)\n",
    "    for i in range(200):\n",
    "        loss = do_one_sentence(rnn, params, sentence)\n",
    "        loss_value = loss.value()\n",
    "        loss.backward()\n",
    "        trainer.update()\n",
    "        if i % 5 == 0: \n",
    "            print(\"%.10f\" % loss_value, end=\"\\t\")\n",
    "            print(generate(rnn, params))\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that:\n",
    "1. We pass the same rnn-builder to `do_one_sentence` over and over again.\n",
    "We must re-use the same rnn-builder, as this is where the shared parameters are kept.\n",
    "2. We `dy.renew_cg()` before each sentence -- because we want to have a new graph (new network) for this sentence.\n",
    "The parameters will be shared through the model and the shared rnn-builder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "149.1241455078\tczdz\n",
      "106.8910675049\tvvaleamkxnmgiehjqziydnqw wvdyqvocv\n",
      "70.4573440552\t\n",
      "45.8460693359\tk upcel romw ufec dbo tzhw\n",
      "23.2288417816\tp g\n",
      "10.3553504944\td slick boxwn xox\n",
      "3.4879603386\ta quick browncfox jumped oher dhe lazy dog\n",
      "1.2615398169\ta quick brown fox jumped over the lazy dog\n",
      "0.8019771576\tq quick brown fox jumped kver the lazy dog\n",
      "0.5865874290\ta quick brown fox jumped over the lazy dog\n",
      "0.4615180790\ta quick brown fox jumped over the lazy dog\n",
      "0.3798972070\ta quick brown fox jumpgd over the lazy dog\n",
      "0.3224956095\ta quick brown fox jumped over the lazy dog\n",
      "0.2799601257\ta quick brown fox jumped ovec the lazy dog\n",
      "0.2471984029\ta quick brown fox jumped over the lazy dog\n",
      "0.2212039977\ta quick brown fox jumped over ehe lazy dog\n",
      "0.2000847459\ta quick brown fox jumped over the lazy dog\n",
      "0.1825926602\ta quick brown fox jumped over the lazy dog\n",
      "0.1678716689\ta quick brown fox jumped over the lazy dog\n",
      "0.1553139985\ta quick brown fox jumped over the lazy dog\n",
      "0.1444785446\ta quick brown fox jumped over the lazy dog\n",
      "0.1350349337\ta quick brewn fox jumped over the lazy dog\n",
      "0.1267340034\ta quick brown fox jumped over the lazy dog\n",
      "0.1193781197\ta quick brown fox jumped over the lazy dog\n",
      "0.1128196493\ta quick brown fox jumped over the lazy dog\n",
      "0.1069323123\ta quick brown fox jumped over the lazy dog\n",
      "0.1016217321\ta quick brown fox jumped over the lazy dog\n",
      "0.0968040079\ta quick brown fox jumped over the lazy dog\n",
      "0.0924172848\ta quick brown fox jumped over the lazy dog\n",
      "0.0884065777\ta quick brown fox jumped over the lazy dog\n",
      "0.0847217217\ta quick brown fox jumped over the lazy dog\n",
      "0.0813286975\ta quick brown fox jumped over the lazy dog\n",
      "0.0781916752\ta quick brown fox vumped over the lazy dog\n",
      "0.0752858222\ta quick brown fox jumped over the lazy dog\n",
      "0.0725848153\ta quick brown fox jumped over the lazy dog\n",
      "0.0700684935\ta quick brown fox jumped over the lazy dog\n",
      "0.0677180886\ta quick brown fox jumped over the lazy dog\n",
      "0.0655169189\ta quick brown fox jumped over the lazy dog\n",
      "0.0634531230\ta quick brown fox jumped over the lazy dog\n",
      "0.0615142360\ta quick brown fox jumped over the lazy dog\n"
     ]
    }
   ],
   "source": [
    "sentence = \"a quick brown fox jumped over the lazy dog\"\n",
    "train(srnn, params_srnn, sentence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "143.8308105469\trsfmqxt ozqsiiaqa\n",
      "128.6366729736\tv ypvoobknwepfeply\n",
      "121.5253829956\tymquqr wfmocoe  ovwuwfmhdm ueod yewe\n",
      "115.6775283813\tq thluk  nwzwz eoodzod\n",
      "109.5644912720\tkxni bpuch xj enu mr ung omj dp eevem r vyyd t p lt  oyqbr\n",
      "102.8272857666\tbqkbhfmnb o mppoeoegoyt ddl rusay l  da\n",
      "98.4713058472\tqosqhn poafr  of uhexedoo pe h etavopyd pyiy d o yee al slghh\n",
      "90.9259567261\ta a qakmbn bm qcoayper efoyeroldpddm\n",
      "85.2885818481\tjku u upoowwbj jvdemmdfdeduree ood  oogdogpq dlto  y agzog i g gdlzac fokn  ux po opu uvlrr e eer rae ed ogy oel olzz\n",
      "76.7979202271\t\n",
      "71.5208969116\tq bauqn kkowcon ffojjpfp ox ouvtt e lzuy dv hoty dgggo oqjgkgo  oonx oxm om vee  eeo o ad\n",
      "62.6548461914\tdc zbrb oqn xomper joehpee eztlazd lqau\n",
      "56.1585731506\thowunqm oofw ojpder re rezt ogavyy  dogdcwo \n",
      "49.3944778442\tt bc qouwr rw o bo xm ojumer r ree ele azlad do\n",
      "41.5289344788\th uuikr ob wox jumepd rr loy ulz do\n",
      "36.3642997742\tdw ucown oox mx op jadmee ther loh\n",
      "30.3189773560\ta qqucakb onw fn jumjee oe tee taza dzo\n",
      "24.9423580170\ta quackk bborn fox juumpedde ove verr azy dogg\n",
      "19.8645935059\t aauk qbrrr oorf fmomed oee the layy oys\n",
      "15.7765054703\t aukc irr brow fox jumedd oveer aauzy dog\n",
      "12.4693098068\ta quikc brown foo jope dve ovr lay dogy dbog \n",
      "9.8480806351\tucc brown fox juxmpe ooer the llazy dog\n",
      "7.8634152412\ta quiic brronn fox jumed over the lazy dog\n",
      "5.9515495300\ta quick brown x jxx jjumed over the lzyy log\n",
      "4.5509667397\ta quic bronn fox jumpdd over the lazy dog\n",
      "3.4462499619\ta quic brown fox jumped oper the lazy dog\n",
      "2.2497565746\ta qgiick brown fox jumped over the lazy dog\n",
      "1.7854881287\ta quicr brown fox jumped over the lazy dog\n",
      "1.4716305733\ta quick brown n fumpped over the lazy dog\n",
      "1.2449830770\ta quick brown fox jumped over th lazy dog\n",
      "1.0744248629\ta quick brown fox jumped over the lazy dog\n",
      "0.9419770241\ta qick brown fox jumped over the lazy dog\n",
      "0.8365104198\ta quik brown fox jumped over the lazy dog\n",
      "0.7507840395\ta quick brown fox jumped over the lazy dog\n",
      "0.6798918843\ta quick brown fox jumped over the lazy dog\n",
      "0.6204041839\ta quick brown fox jumped over the lazy dog\n",
      "0.5698559284\ta quick brown fox jumped over the lazy dog\n",
      "0.5264287591\ta quick brown fox jumped over the lazy ddog\n",
      "0.4887620807\ta quick brown fox jumped over the lazy dog\n",
      "0.4558122456\ta quick brown fomx jumped over the lazy dog\n"
     ]
    }
   ],
   "source": [
    "sentence = \"a quick brown fox jumped over the lazy dog\"\n",
    "train(lstm, params_lstm, sentence)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model seem to learn the sentence quite well.\n",
    "\n",
    "Somewhat surprisingly, the Simple-RNN model learn quicker than the LSTM! (If you increase the number of layers, this difference will be even more pronounced)\n",
    "\n",
    "How can that be?\n",
    "\n",
    "The answer is that we are cheating a bit. The sentence we are trying to learn\n",
    "has each letter-bigram exactly once. This means a simple trigram model can memorize\n",
    "it very well.\n",
    "\n",
    "Try it out with more complex sequences.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "239.8676452637\ta quick brown fox jumped over the lazy dog\n",
      "160.0521697998\ta quick lrojny la tumped over the lazy driwn lrlazthe lazmzover the lazy dvgrn zrldnthd qheed gex thed bmne the lazy dbqre lazd yrd nrkwn fox jumped over the lazy dhe lazy dog\n",
      "95.4504623413\ta quick brzy e lazy e greaze erg\n",
      "63.7105903625\tahqne lretze n mve lbmed tizd mare a\n",
      "40.0441474915\taaquicpr the laz\n",
      "21.2409267426\taze ecpretztli arirltzixtoped bt\n",
      "7.8475003242\tthese pritzels pre maling me thirsty\n",
      "2.9435257912\tthese pretz ls whe taking me thirsty\n",
      "1.0504565239\tthese pretzels are making me thirsty\n",
      "0.4977459908\tthese pretzels are making me the mty\n",
      "0.3605296910\tthese pretz\n",
      "0.2839964628\tthese pretzels are making me ghirsty\n",
      "0.2345837206\tthese pretzels are making me thirsty\n",
      "0.1999202222\tthese pretzels are making me thirsty\n",
      "0.1742218733\tthese pretzels are making me thirsty\n",
      "0.1543948352\tthese pretzels are making me thirsty\n",
      "0.1386296600\tthese pretzels are making me thirsty\n",
      "0.1257905215\tthese pretzels are making me thirsty\n",
      "0.1151323095\tthese pretzels are making me thirsty\n",
      "0.1061392874\ttmese pretzels are making me thirsty\n",
      "0.0984523371\tthese pretzels are making me thirsty\n",
      "0.0918025821\tthese pretzels are making me thirsty\n",
      "0.0859967992\tthese pretzels are making me thirsty\n",
      "0.0808787867\tthese pretzels are making me thirsty\n",
      "0.0763375387\tthese pretzels are making me thirsty\n",
      "0.0722793266\tthese pretzels are making me thirsty\n",
      "0.0686300844\tthese pretzels are making me thirsty\n",
      "0.0653314814\tthese pretzels are making me thirsty\n",
      "0.0623353273\tthese pretzels are making me thirsty\n",
      "0.0596007779\tthese pretzels are making me thirsty\n",
      "0.0570969619\tthese pretzels are making me thirsty\n",
      "0.0547946990\tthese pretzels are making me thirsty\n",
      "0.0526688434\tthese pretzels are making me thirsty\n",
      "0.0507033207\tthese pretzels are making me thirsty\n",
      "0.0488782115\tthese pretzels are making me thirsty\n",
      "0.0471798219\tthese pretzels are making me thirsty\n",
      "0.0455954187\tthese pretzels are making me thirsty\n",
      "0.0441129319\tthese pretzels are making me thirsty\n",
      "0.0427242145\tthese pretzels are making me thirsty\n",
      "0.0414196625\tthese pretzels are making me thirsty\n"
     ]
    }
   ],
   "source": [
    "train(srnn, params_srnn, \"these pretzels are making me thirsty\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
